"""
 @Author: zzgsg
 @time: 2024/9/28 20:32
"""

# Figure 6. Profiles retrieved by USTC-PRM and OEM in CAMS.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
import os
from warnings import filterwarnings
from scipy.interpolate import interp1d
from matplotlib.ticker import MultipleLocator, FormatStrFormatter
import matplotlib.cm, matplotlib.colors
import sys

matplotlib.rcParams['mathtext.default'] = 'regular'


def data_load(datei):
    """Load the sixteen per-day arrays stored under ``case/``.

    ``str(datei)`` with the dashes removed is used as the file-name prefix,
    e.g. ``case/20191128_aero.npy``. For each of the four products
    (``aero``, ``no2``, ``oe_aero``, ``oe_no2``) the profile array is read,
    followed by its ``_time``, ``_bool`` and ``_geom`` companions.

    The two NO2 profile arrays are multiplied by 46 and divided by
    6.02*10**11 right after loading (presumably a number-density to
    mass-concentration conversion -- confirm the input units).

    Returns a 16-tuple ordered as
    (aero, aero_time, aero_bool, aero_geom, no2, ..., oe_no2_geom).
    """
    stem = 'case/' + str(datei).replace('-', '') + '_'

    def fetch(tag):
        # One array per file: case/<yyyymmdd>_<tag>.npy
        return np.load(stem + tag + '.npy')

    loaded = []
    for product in ('aero', 'no2', 'oe_aero', 'oe_no2'):
        profile = fetch(product)
        if product.endswith('no2'):
            # Same operation order as a literal `arr*46/(6.02*10**11)`.
            profile = profile * 46 / (6.02 * 10**11)
        loaded.append(profile)
        for suffix in ('_time', '_bool', '_geom'):
            loaded.append(fetch(product + suffix))
    return tuple(loaded)

def plt_ax(ax,X,Z,label, xvis=None, title=None, vmax=1):
    """Draw one time-height panel on *ax* with ``pcolormesh``.

    ax    -- matplotlib Axes to draw on.
    X     -- numeric time stamps; their last six decimal digits are decoded
             as hhmmss and turned into decimal hours. At least two entries
             are required to build the cell edges.
    Z     -- values to colour (presumably shaped layers x times -- confirm).
    label -- only inspected for the substring 'OE', which selects the
             height grid (fine 0.1 km steps below ~1 km then 0.2 km steps
             up to 3.1 km, versus uniform 0.1 km steps up to 4 km).
    xvis  -- when not None, the x axis gets a label and ticks at 1, 4, 7
             labelled 9, 12, 15; otherwise x ticks are removed.
    title -- optional panel title.
    vmax  -- upper colour limit (lower limit is fixed at 0).

    Reads the module-level global ``num``: only the panel with ``num == 0``
    keeps its y ticks and gets the 'Height [km]' label.
    """
    # Layer boundaries in km; the grid depends on which retrieval is shown.
    heights = (np.r_[[0], np.arange(0.05, 1.01, 0.1), np.arange(1.1, 3.11, 0.2)]
               if 'OE' in label
               else np.r_[[0], np.arange(0.1, 4.01, 0.1)])

    # hhmmss digits -> decimal hours.
    whole_hours = (X % 1000000) // 1E4
    whole_minutes = (X % 10000) // 1E2
    seconds = X % 100
    hours = whole_hours + whole_minutes / 60 + seconds / 3600

    # Cell edges: midpoints between samples, extrapolated half a step at
    # both ends.
    first_edge = hours[0] - (hours[1] - hours[0]) / 2
    inner_edges = (hours[1:] + hours[:-1]) / 2
    last_edge = hours[-1] + ((hours[-1] - hours[-2]) / 2)
    edges = np.r_[first_edge, inner_edges, last_edge]

    for side in ('bottom', 'left', 'right', 'top'):
        ax.spines[side].set_linewidth(4)

    ax.set_ylim([0, 3])
    ax.set_yticks([0, 1, 2, 3])
    ax.set_xlim([0, 8])
    for tick_labels in (ax.get_xticklabels(), ax.get_yticklabels()):
        plt.setp(tick_labels, fontsize=25)

    if title is not None:
        ax.set_title(title, fontsize=25)

    # Only the left-most column carries the height axis annotation.
    if num != 0:
        ax.set_yticks([])
    else:
        ax.set_ylabel('Height [km]', fontsize=25)

    if xvis is None:
        ax.set_xticks([])
    else:
        ax.set_xlabel('Time (Local)', fontsize=25)
        ax.set_xticks([1,4,7])
        ax.set_xticklabels([9,12,15])

    ax.tick_params(axis='both', which='major', length=15, width=4)
    ax.pcolormesh(edges, heights, Z, cmap='Spectral_r', shading='auto', vmin=0, vmax=vmax)

def _span_colorbar(top_ax, bottom_ax, vmax, ticks, fmt, label):
    """Add a narrow colorbar at the right figure edge spanning two rows.

    The bar starts at the bottom of *bottom_ax* and ends at the top of
    *top_ax* (figure coordinates), uses the 'Spectral_r' colormap
    normalised to [0, vmax], and is captioned with *label*.
    """
    top = top_ax.get_position().bounds        # (x0, y0, width, height)
    bottom = bottom_ax.get_position().bounds
    cax = plt.axes([0.955, bottom[1], 0.01, top[3] + top[1] - bottom[1]])
    mappable = matplotlib.cm.ScalarMappable(
        norm=matplotlib.colors.Normalize(vmin=0, vmax=vmax), cmap='Spectral_r')
    # NOTE(review): shrink/pad are kept from the original call; they are
    # presumably without effect when an explicit `cax` is supplied.
    cbar = plt.colorbar(mappable, cax=cax, shrink=0.6, pad=0, ticks=ticks,
                        format=fmt, extend='both')
    cbar.ax.tick_params(labelsize=15)
    cbar.set_label(label, fontsize=20)


def bars(axs):
    """Add the two shared colorbars and the row captions to the figure.

    axs -- 2-D array of Axes with at least four rows; only column 0 is
           used, to read row positions and to anchor the captions.

    Rows 0-1 share the extinction colorbar (0..1), rows 2-3 share the NO2
    colorbar (0..40). Each colorbar's height is taken from the rows it
    actually spans, so it stays aligned even if row heights differ.
    """
    _span_colorbar(axs[0, 0], axs[1, 0], vmax=1,
                   ticks=[0, 0.2, 0.4, 0.6, 0.8, 1], fmt='%.1f',
                   label='Extinction @360nm [1/km]')
    _span_colorbar(axs[2, 0], axs[3, 0], vmax=40,
                   ticks=[0, 10, 20, 30, 40], fmt='%d',
                   label='$NO_2$ concentration [μg/$m^{3}$]')

    # Rotated captions to the left of the first column, one per row.
    row_captions = ('USTC-PRM\nAEROSOL', 'OEM\nAEROSOL', 'USTC-PRM\n$NO_2$', 'OEM\n$NO_2$')
    for row, caption in enumerate(row_captions):
        axs[row, 0].annotate(caption, xy=(-0.5, 0.5), xycoords='axes fraction', rotation=90, fontsize=30,
                             horizontalalignment='center', verticalalignment='center')

def _mask_and_plot(ax, times, profile, valid, label, **plot_kwargs):
    """Set rejected time steps of *profile* to NaN in place, then plot it.

    A time step is kept only when *valid* is true for it and its hour stamp
    (``times // 10000``) is not listed in the module-level ``booltime2``.
    """
    cloud_free = np.isin(np.round(times//10000), booltime2, invert=True)
    # `.T` is a view, so this writes NaN into `profile` itself; the mask is
    # applied along the first axis of the transposed array.
    profile.T[~(valid&cloud_free)] = np.nan
    plt_ax(ax, times, profile, label, **plot_kwargs)


def axs_deal(axs):
    """Fill the panel grid: one column per date, one row per product.

    axs -- 2-D array of Axes indexed as ``axs[row, column]`` with rows
           0..3 = USTC-PRM aerosol, OEM aerosol, USTC-PRM NO2, OEM NO2.

    Dates run over ``np.arange(begin, end, dtype='datetime64')`` using the
    module-level ``begin``/``end``. The loop index is published through the
    module-level global ``num`` because ``plt_ax`` reads it.

    A date whose data cannot be loaded or plotted is skipped (best effort),
    but the reason is reported on stderr instead of being silently dropped,
    and KeyboardInterrupt/SystemExit are no longer swallowed.
    """
    global num
    for num, datei in enumerate(np.arange(begin, end, dtype='datetime64')):
        try:
            (aero, aero_time, aero_bool, aero_geom,
             no2, no2_time, no2_bool, no2_geom,
             oe_aero, oe_aero_time, oe_aero_bool, oe_aero_geom,
             oe_no2, oe_no2_time, oe_no2_bool, oe_no2_geom) = data_load(datei)

            _mask_and_plot(axs[0, num], aero_time, aero, aero_bool&aero_geom,
                           'OPPA Extinction @360nm [1/km]', vmax=1, title=str(datei))
            _mask_and_plot(axs[1, num], oe_aero_time, oe_aero, oe_aero_bool&oe_aero_geom,
                           'OE Extinction @360nm [1/km]', vmax=1)
            # NOTE(review): the USTC-PRM NO2 mask also requires aero_bool,
            # while the OEM NO2 mask does not use oe_aero_bool -- confirm
            # this asymmetry is intended.
            _mask_and_plot(axs[2, num], no2_time, no2, aero_bool&no2_bool&no2_geom,
                           'OPPA $NO_2$ [μg/$m^{3}$]', vmax=40)
            _mask_and_plot(axs[3, num], oe_no2_time, oe_no2, oe_no2_bool&oe_no2_geom,
                           'OE $NO_2$ [μg/$m^{3}$]', vmax=40, xvis=1)
        except Exception as exc:
            # Keep going with the next date, but leave a trace of what failed.
            print(f'axs_deal: skipped column {num} ({datei}): {exc!r}', file=sys.stderr)

def get_cloud():
    """Return the hour stamps whose column-9 value is 7.

    Reads the whitespace-separated tables ``545110-99999-2019`` and
    ``545110-99999-2020`` from the working directory and stacks them.
    (Presumably hourly station records with year, month, day, hour in
    columns 0-3 and a sky-cover code in column 9 -- confirm the format.)

    Column-9 values below -1 are treated as missing. For every row except
    the last two, a missing value is replaced by the original value of the
    next row and, if that is missing too, by the original value two rows
    ahead.

    Rows whose (filled) column 9 lies strictly between 6.9 and 7.1 are
    selected, and their columns 0-3 are packed into integers as
    ``c0*1000000 + c1*10000 + c2*100 + c3``.
    """
    records = np.concatenate((np.loadtxt(r"545110-99999-2019"),
                              np.loadtxt(r"545110-99999-2020")), axis=0)
    filled = records.copy()

    # View on column 9 of every row but the last two; writes go to `filled`.
    fillable = filled[:-2, 9]
    total = len(records)
    for lookahead in (1, 2):
        missing = fillable < -1
        donors = records[lookahead:total - 2 + lookahead, 9]
        fillable[missing] = donors[missing]

    code = filled[:, 9]
    flagged = filled[(code > 6.9) & (code < 7.1)]
    stamps = (flagged[:, 0] * 1000000 + flagged[:, 1] * 10000
              + flagged[:, 2] * 100 + flagged[:, 3])
    return np.round(stamps).astype(int)

# Hour stamps returned by get_cloud(); read as a module-level global when
# the panels are masked in axs_deal().
booltime2 = get_cloud()
# Date range of the figure, read as globals by axs_deal(); it is expanded
# with np.arange, so `end` itself is excluded (six days -> six columns).
begin = '2019-11-28'
end = '2019-12-04'
plt.rcParams['mathtext.default'] = 'default'
plt.rcParams['font.size'] = 12
# 4 rows (products) x 6 columns (days).
fig, axs = plt.subplots(4,6, figsize=(23,14), dpi=90)
plt.subplots_adjust(top=0.95, bottom=0.07, right=0.94, left=0.1, hspace=0.1, wspace=0.1)
axs_deal(axs)
bars(axs)
# Make sure the output directory exists so the figure is not lost to a
# FileNotFoundError after all the work above.
os.makedirs('temp', exist_ok=True)
plt.savefig('temp/Figure6.eps')
plt.savefig('temp/Figure6.png')
plt.close()
